import os
import torch
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
from glob import glob
import pandas as pd
import itertools
from utils import preprocess, compute_mmd

# Choose where computation runs: the CUDA device when torch reports one is
# available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using device: {device}')

# Feature extractor: the DINO ResNet-50 entry point fetched through torch.hub,
# placed on the chosen device and switched to eval mode for inference.
encoder = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')
encoder = encoder.to(device)
encoder.eval()

# Function to extract features in batches
def extract_features_batch(image_paths, feature_extractor, batch_size=16):
    """Run ``feature_extractor`` over images in mini-batches and stack the outputs.

    Each image is opened, converted to RGB, passed through the project's
    ``preprocess`` transform, stacked into a batch, moved to the module-level
    ``device`` and fed through ``feature_extractor`` with autograd disabled.

    Args:
        image_paths: Sequence of image file paths (it is measured with ``len``
            and sliced, so a list/tuple is expected).
        feature_extractor: Callable applied to each stacked batch tensor.
        batch_size: Number of images per forward pass; must be positive.

    Returns:
        The per-batch outputs concatenated along dim 0 and kept on the CPU,
        or an empty tensor when ``image_paths`` is empty.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    # range() would raise on 0 and silently yield nothing on a negative step;
    # fail loudly and uniformly instead.
    if batch_size <= 0:
        raise ValueError(f'batch_size must be positive, got {batch_size}')

    features = []
    for start in tqdm(range(0, len(image_paths), batch_size), desc="Extracting Features"):
        batch_paths = image_paths[start:start + batch_size]
        imgs = []
        for path in batch_paths:
            # Close each file promptly; convert() yields a new, loaded image,
            # so nothing depends on the handle after the with-block.
            with Image.open(path) as img:
                imgs.append(preprocess(img.convert('RGB')))
        imgs_tensor = torch.stack(imgs).to(device)
        with torch.no_grad():
            batch_features = feature_extractor(imgs_tensor)
        # Accumulate on the CPU so device memory holds only one batch at a time.
        features.append(batch_features.cpu())
    if features:
        return torch.cat(features, dim=0)
    return torch.tensor([])

# Load reference features. The loaded object is iterated and indexed like a
# dict (the main loop looks values up by (gender, profession) tuples).
# map_location=device remaps tensors at load time, so a file saved from a GPU
# session can still be loaded on a CPU-only machine instead of failing.
reference_features_path = 'reference_features.pt'
reference_features = torch.load(reference_features_path, map_location=device)
# Ensure every stored tensor lives on the target device (a no-op for tensors
# already remapped by map_location).
for key in reference_features:
    reference_features[key] = reference_features[key].to(device)

# Function to get generated image paths
def get_generated_image_paths(base_output_dir, lambda_reg, lambda_2, gender, profession,
                              guidance_scale=7.5):
    """Return the generated PNG paths for one (lambda, gender, profession) run.

    Images are looked up in the run directory
    ``<base_output_dir>/resnet_lambda_<lambda_reg>_lambda2_<lambda_2>_guidance_<guidance_scale>``
    and must match ``*_<gender>_<profession>_*.png``.

    Args:
        base_output_dir: Root folder containing one sub-directory per run.
        lambda_reg: First lambda value embedded in the run directory name.
        lambda_2: Second lambda value embedded in the run directory name.
        gender: Gender label embedded in the file names.
        profession: Profession label embedded in the file names.
        guidance_scale: Guidance value embedded in the run directory name;
            defaults to 7.5, the value previously hard-coded here.

    Returns:
        Sorted list of matching paths; empty when the directory is missing or
        has no matches. Sorting makes the order (and therefore the row order
        of saved embeddings) deterministic, because ``glob`` returns entries
        in filesystem-dependent order.
    """
    from glob import escape as glob_escape

    dir_name = f'resnet_lambda_{lambda_reg}_lambda2_{lambda_2}_guidance_{guidance_scale}'
    output_dir = os.path.join(base_output_dir, dir_name)
    # Escape the literal parts so characters such as '[' in a directory name
    # or label are matched literally instead of as glob wildcards.
    pattern = os.path.join(
        glob_escape(output_dir),
        f'*_{glob_escape(str(gender))}_{glob_escape(str(profession))}_*.png',
    )
    return sorted(glob(pattern))

# Sweep configuration: lambda grid plus the gender/profession labels to score.
lambda_values = [20]
lambda2s = [1]
genders = ['male', 'female']
professions = ['nurse', 'firefighter']

# Root folder that holds one sub-directory of generated images per run.
base_output_dir = 'generation_results'

# Destination for the per-combination feature tensors; created up front.
embedding_output_dir = 'saved_embeddings'
os.makedirs(name=embedding_output_dir, exist_ok=True)

# Accumulates one result dict per evaluated combination.
results = list()

# Iterate over all combinations
# Main sweep: for each (lambda_reg, lambda_2) pair that survives the filter
# below, and each gender/profession, extract features from the generated
# images, save them to disk, and append the MMD against the matching reference
# features to `results`.
for lambda_reg, lambda_2 in tqdm(itertools.product(lambda_values, lambda2s), desc='Lambda Combinations', total=len(lambda_values)*len(lambda2s)):
    # Filter the lambda grid. Net effect of the two checks below:
    #   * lambda_reg == 0 is kept only together with lambda_2 == 0;
    #   * any other pair is kept only if lambda_reg == 20 or lambda_2 == 1.
    # NOTE(review): this presumably sweeps one parameter at a time around
    # (20, 1) plus a (0, 0) baseline; confirm. The values 20 and 1 are
    # hard-coded here rather than taken from lambda_values / lambda2s.
    if lambda_reg != 20 and lambda_2 != 1:
        if lambda_reg == 0:
            pass  # lambda_reg == 0 is decided by the next check instead
        else:
            continue
    if lambda_reg == 0 and lambda_2 != 0:
        continue

    for gender in genders:
        for profession in professions:
            print(f'Processing lambda_reg: {lambda_reg}, lambda_2: {lambda_2}, gender: {gender}, profession: {profession}')
            
            # Get generated image paths
            image_paths = get_generated_image_paths(base_output_dir, lambda_reg, lambda_2, gender, profession)
            if not image_paths:
                print('No images found for this combination.')
                continue  # Skip if no images found
            
            # Extract features from generated images. The helper returns them
            # on the CPU; move them to `device` for the MMD computation.
            # Assumes the encoder yields one feature row per image (N, D); confirm.
            gen_features = extract_features_batch(image_paths, encoder).to(device)

            # Save the embeddings for this lambda_reg, lambda_2, gender, and
            # profession (as a CPU tensor). This happens before the reference
            # lookup, so embeddings are written even when no reference exists.
            embedding_save_path = os.path.join(
                embedding_output_dir,
                f'embeddings_lambda_reg_{lambda_reg}_lambda_2_{lambda_2}_{gender}_{profession}.pt'
            )
            torch.save(gen_features.cpu(), embedding_save_path)
            print(f'Saved embeddings to {embedding_save_path}')
            
            # Get reference features, keyed by (gender, profession).
            key = (gender, profession)
            if key not in reference_features:
                print(f'No reference features found for {key}.')
                continue
            # Already moved to `device` when loaded; this .to() is then a no-op.
            ref_features = reference_features[key].to(device)
            
            # Compute MMD for this gender-profession combination.
            # compute_mmd (from utils) must return a single-element tensor,
            # since .item() is called on it.
            mmd_value = compute_mmd(gen_features, ref_features).item()
            print(f'MMD: {mmd_value}')
            
            # Collect gender-specific result (these keys become the CSV columns).
            results.append({
                'lambda_reg': lambda_reg,
                'lambda_2': lambda_2,
                'gender': gender,
                'profession': profession,
                'MMD': mmd_value
            })
    
    # Additional code for combined 'all' embeddings can be added here if needed.

# Build the results table. The column names are passed explicitly (matching
# the keys of each result dict) so the CSV keeps its header and column order
# even when no combination produced a result. Without them, an empty
# `results` list yields a header-less file.
result_columns = ['lambda_reg', 'lambda_2', 'gender', 'profession', 'MMD']
df_results = pd.DataFrame(results, columns=result_columns)

# Save results to CSV
df_results.to_csv('image_quality_results.csv', index=False)
print('Results saved to image_quality_results.csv')
